{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from sklearn.utils import shuffle\n",
    "import re\n",
    "import time\n",
    "import collections\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_dataset(words, n_words, atleast=1):\n",
    "    count = [['PAD', 0], ['GO', 1], ['EOS', 2], ['UNK', 3]]\n",
    "    counter = collections.Counter(words).most_common(n_words)\n",
    "    counter = [i for i in counter if i[1] >= atleast]\n",
    "    count.extend(counter)\n",
    "    dictionary = dict()\n",
    "    for word, _ in count:\n",
    "        dictionary[word] = len(dictionary)\n",
    "    data = list()\n",
    "    unk_count = 0\n",
    "    for word in words:\n",
    "        index = dictionary.get(word, 0)\n",
    "        if index == 0:\n",
    "            unk_count += 1\n",
    "        data.append(index)\n",
    "    count[0][1] = unk_count\n",
    "    reversed_dictionary = dict(zip(dictionary.values(), dictionary.keys()))\n",
    "    return data, count, dictionary, reversed_dictionary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "len from: 500, len to: 500\n"
     ]
    }
   ],
   "source": [
    "with open('english-train', 'r') as fopen:\n",
    "    text_from = fopen.read().lower().split('\\n')[:-1]\n",
    "with open('vietnam-train', 'r') as fopen:\n",
    "    text_to = fopen.read().lower().split('\\n')[:-1]\n",
    "print('len from: %d, len to: %d'%(len(text_from), len(text_to)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vocab from size: 1935\n",
      "Most common words [(',', 564), ('.', 477), ('the', 368), ('and', 286), ('to', 242), ('of', 220)]\n",
      "Sample data [482, 483, 78, 6, 137, 484, 10, 226, 787, 14] ['rachel', 'pike', ':', 'the', 'science', 'behind', 'a', 'climate', 'headline', 'in']\n"
     ]
    }
   ],
   "source": [
    "concat_from = ' '.join(text_from).split()\n",
    "vocabulary_size_from = len(list(set(concat_from)))\n",
    "data_from, count_from, dictionary_from, rev_dictionary_from = build_dataset(concat_from, vocabulary_size_from)\n",
    "print('vocab from size: %d'%(vocabulary_size_from))\n",
    "print('Most common words', count_from[4:10])\n",
    "print('Sample data', data_from[:10], [rev_dictionary_from[i] for i in data_from[:10]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vocab to size: 1461\n",
      "Most common words [(',', 472), ('.', 430), ('tôi', 283), ('và', 230), ('có', 199), ('chúng', 196)]\n",
      "Sample data [84, 22, 668, 73, 10, 389, 110, 34, 81, 299] ['khoa', 'học', 'đằng', 'sau', 'một', 'tiêu', 'đề', 'về', 'khí', 'hậu']\n"
     ]
    }
   ],
   "source": [
    "concat_to = ' '.join(text_to).split()\n",
    "vocabulary_size_to = len(list(set(concat_to)))\n",
    "data_to, count_to, dictionary_to, rev_dictionary_to = build_dataset(concat_to, vocabulary_size_to)\n",
    "print('vocab to size: %d'%(vocabulary_size_to))\n",
    "print('Most common words', count_to[4:10])\n",
    "print('Sample data', data_to[:10], [rev_dictionary_to[i] for i in data_to[:10]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "GO = dictionary_from['GO']\n",
    "PAD = dictionary_from['PAD']\n",
    "EOS = dictionary_from['EOS']\n",
    "UNK = dictionary_from['UNK']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(len(text_to)):\n",
    "    text_to[i] += ' EOS'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def str_idx(corpus, dic):\n",
    "    X = []\n",
    "    for i in corpus:\n",
    "        ints = []\n",
    "        for k in i.split():\n",
    "            ints.append(dic.get(k,UNK))\n",
    "        X.append(ints)\n",
    "    return X\n",
    "\n",
    "def pad_sentence_batch(sentence_batch, sentence_batch_y, pad_int):\n",
    "    x, y = [], []\n",
    "    max_sentence_len = max([len(sentence) for sentence in sentence_batch])\n",
    "    max_sentence_len_y = max([len(sentence) for sentence in sentence_batch_y])\n",
    "    max_sentence_len = max(max_sentence_len, max_sentence_len_y)\n",
    "    for no, sentence in enumerate(sentence_batch):\n",
    "        x.append(sentence + [pad_int] * (max_sentence_len - len(sentence)))\n",
    "        y.append(sentence_batch_y[no] + [pad_int] * (max_sentence_len - len(sentence_batch_y[no])))\n",
    "    return x, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = str_idx(text_from, dictionary_from)\n",
    "Y = str_idx(text_to, dictionary_to)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "maxlen_question = max([len(x) for x in X]) * 2\n",
    "maxlen_answer = max([len(y) for y in Y]) * 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def layer_normalization(inputs, block_name, reuse, epsilon=1e-8):\n",
    "    with tf.variable_scope(block_name, reuse = reuse):\n",
    "        mean, variance = tf.nn.moments(inputs, [-1], keep_dims=True)\n",
    "        normalized = (inputs - mean) / (tf.sqrt(variance + epsilon))\n",
    "\n",
    "        params_shape = inputs.get_shape()[-1:]\n",
    "        gamma = tf.get_variable('gamma', params_shape, tf.float32, tf.ones_initializer())\n",
    "        beta = tf.get_variable('beta', params_shape, tf.float32, tf.zeros_initializer())\n",
    "\n",
    "        outputs = gamma * normalized + beta\n",
    "        return outputs\n",
    "\n",
    "def conv1d(input_, output_channels, block_name, reuse, dilation = 1, filter_width = 1, causal = False):\n",
    "    with tf.variable_scope(block_name, reuse = reuse):\n",
    "        w = tf.get_variable('w',  [1, filter_width, int(input_.get_shape()[-1]), output_channels],\n",
    "                            tf.float32, tf.initializers.random_normal(stddev = 0.02))\n",
    "        b = tf.get_variable('b',  [output_channels],\n",
    "                            tf.float32, tf.zeros_initializer())\n",
    "        if causal:\n",
    "            padding = [[0, 0], [(filter_width - 1) * dilation, 0], [0, 0]]\n",
    "            padded = tf.pad(input_, padding)\n",
    "            input_expanded = tf.expand_dims(padded, dim = 1)\n",
    "            out = tf.nn.atrous_conv2d(input_expanded, w, rate = dilation, padding = 'VALID') + b\n",
    "        else:\n",
    "            input_expanded = tf.expand_dims(input_, dim = 1)\n",
    "            out = tf.nn.atrous_conv2d(input_expanded, w, rate = dilation, padding = 'SAME') + b\n",
    "        return tf.squeeze(out, [1])\n",
    "\n",
    "def bytenet_residual_block(input_, dilation, layer_no, \n",
    "                            residual_channels, filter_width, block_type,\n",
    "                            causal = True, reuse = False):\n",
    "    block_name = \"bytenet_{}_layer_{}_{}\".format(block_type, layer_no, dilation)\n",
    "    print(block_name)\n",
    "    with tf.variable_scope(block_name, reuse = reuse):\n",
    "        relu1 = tf.nn.relu(layer_normalization(input_, block_name + '_0', reuse))\n",
    "        conv1 = conv1d(relu1, residual_channels, block_name + '_0', reuse)\n",
    "        relu2 = tf.nn.relu(layer_normalization(conv1, block_name + '_1', reuse))\n",
    "        dilated_conv = conv1d(relu2, residual_channels,\n",
    "                              block_name + '_1', reuse,\n",
    "                              dilation, filter_width,\n",
    "                              causal = causal)\n",
    "        print(dilated_conv)\n",
    "        relu3 = tf.nn.relu(layer_normalization(dilated_conv, block_name + '_2', reuse))\n",
    "        conv2 = conv1d(relu3, 2 * residual_channels, block_name + '_2', reuse)\n",
    "        return input_ + conv2\n",
    "    \n",
    "class ByteNet:\n",
    "    def __init__(self, from_vocab_size, to_vocab_size, channels, encoder_dilations,\n",
    "                decoder_dilations, encoder_filter_width, decoder_filter_width,\n",
    "                learning_rate = 0.001, beta1=0.5):\n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.Y = tf.placeholder(tf.int32, [None, None])\n",
    "        self.X_seq_len = tf.count_nonzero(self.X, 1, dtype = tf.int32)\n",
    "        self.Y_seq_len = tf.count_nonzero(self.Y, 1, dtype = tf.int32)\n",
    "        batch_size = tf.shape(self.X)[0]\n",
    "        main = tf.strided_slice(self.Y, [0, 0], [batch_size, -1], [1, 1])\n",
    "        target_1 = tf.concat([tf.fill([batch_size, 1], GO), main], 1)\n",
    "        embedding_channels = 2 * channels\n",
    "        max_seq = tf.maximum(tf.reduce_max(self.Y_seq_len), tf.reduce_max(self.X_seq_len))\n",
    "        w_source_embedding = tf.Variable(tf.random_normal([from_vocab_size, \n",
    "                                                           embedding_channels], stddev = 0.02))\n",
    "        w_target_embedding = tf.Variable(tf.random_normal([to_vocab_size, \n",
    "                                                           embedding_channels], stddev = 0.02))\n",
    "        \n",
    "        def forward(x, y, reuse = False):\n",
    "            source_embedding = tf.nn.embedding_lookup(w_source_embedding, x)\n",
    "            target_1_embedding = tf.nn.embedding_lookup(w_target_embedding, y)\n",
    "        \n",
    "        \n",
    "            curr_input = source_embedding\n",
    "            for layer_no, dilation in enumerate(encoder_dilations):\n",
    "                curr_input = bytenet_residual_block(curr_input, dilation, \n",
    "                                                    layer_no, channels, \n",
    "                                                    encoder_filter_width,\n",
    "                                                    'encoder',\n",
    "                                                    causal = False, reuse = reuse)\n",
    "            encoder_output = curr_input\n",
    "            combined_embedding = target_1_embedding + encoder_output\n",
    "            curr_input = combined_embedding\n",
    "            for layer_no, dilation in enumerate(decoder_dilations):\n",
    "                curr_input = bytenet_residual_block(curr_input, dilation, \n",
    "                                                    layer_no, channels, \n",
    "                                                    encoder_filter_width, \n",
    "                                                    'decoder',\n",
    "                                                    causal = False, reuse = reuse)\n",
    "            with tf.variable_scope('logits', reuse = reuse):\n",
    "                return conv1d(curr_input, to_vocab_size, 'logits', reuse)\n",
    "        \n",
    "        self.logits = forward(self.X, target_1)\n",
    "        masks = tf.sequence_mask(self.Y_seq_len, max_seq, dtype=tf.float32)\n",
    "        self.cost = tf.contrib.seq2seq.sequence_loss(logits = self.logits,\n",
    "                                                     targets = self.Y,\n",
    "                                                     weights = masks)\n",
    "        self.optimizer = tf.train.AdamOptimizer(learning_rate).minimize(self.cost)\n",
    "        y_t = tf.argmax(self.logits,axis=2)\n",
    "        y_t = tf.cast(y_t, tf.int32)\n",
    "        self.prediction = tf.boolean_mask(y_t, masks)\n",
    "        mask_label = tf.boolean_mask(self.Y, masks)\n",
    "        correct_pred = tf.equal(self.prediction, mask_label)\n",
    "        correct_index = tf.cast(correct_pred, tf.float32)\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))\n",
    "        \n",
    "        def cond(i, y, temp):\n",
    "            return i < tf.reduce_max(max_seq)\n",
    "        \n",
    "        def body(i, y, temp):\n",
    "            logits = forward(self.X, y, reuse = True)\n",
    "            ids = tf.argmax(logits, -1)[:, i]\n",
    "            ids = tf.expand_dims(ids, -1)\n",
    "            temp = tf.concat([temp[:, 1:], ids], -1)\n",
    "            y = tf.concat([temp[:, -(i+1):], temp[:, :-(i+1)]], -1)\n",
    "            y = tf.reshape(y, [tf.shape(temp)[0], max_seq])\n",
    "            i += 1\n",
    "            return i, y, temp\n",
    "        \n",
    "        target = tf.fill([batch_size, max_seq], GO)\n",
    "        target = tf.cast(target, tf.int64)\n",
    "        self.target = target\n",
    "        \n",
    "        _, self.predicting_ids, _ = tf.while_loop(cond, body, \n",
    "                                                  [tf.constant(0), target, target])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "residual_channels = 128\n",
    "encoder_dilations = [1,2,4,8,16,1,2,4,8,16]\n",
    "decoder_dilations = [1,2,4,8,16,1,2,4,8,16]\n",
    "encoder_filter_width = 3\n",
    "decoder_filter_width = 3\n",
    "batch_size = 16\n",
    "epoch = 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "bytenet_encoder_layer_0_1\n",
      "WARNING:tensorflow:From <ipython-input-11-1805b9e945ed>:25: calling expand_dims (from tensorflow.python.ops.array_ops) with dim is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use the `axis` argument instead\n",
      "Tensor(\"bytenet_encoder_layer_0_1/bytenet_encoder_layer_0_1_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n",
      "bytenet_encoder_layer_1_2\n",
      "Tensor(\"bytenet_encoder_layer_1_2/bytenet_encoder_layer_1_2_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n",
      "bytenet_encoder_layer_2_4\n",
      "Tensor(\"bytenet_encoder_layer_2_4/bytenet_encoder_layer_2_4_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n",
      "bytenet_encoder_layer_3_8\n",
      "Tensor(\"bytenet_encoder_layer_3_8/bytenet_encoder_layer_3_8_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n",
      "bytenet_encoder_layer_4_16\n",
      "Tensor(\"bytenet_encoder_layer_4_16/bytenet_encoder_layer_4_16_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n",
      "bytenet_encoder_layer_5_1\n",
      "Tensor(\"bytenet_encoder_layer_5_1/bytenet_encoder_layer_5_1_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n",
      "bytenet_encoder_layer_6_2\n",
      "Tensor(\"bytenet_encoder_layer_6_2/bytenet_encoder_layer_6_2_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n",
      "bytenet_encoder_layer_7_4\n",
      "Tensor(\"bytenet_encoder_layer_7_4/bytenet_encoder_layer_7_4_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n",
      "bytenet_encoder_layer_8_8\n",
      "Tensor(\"bytenet_encoder_layer_8_8/bytenet_encoder_layer_8_8_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n",
      "bytenet_encoder_layer_9_16\n",
      "Tensor(\"bytenet_encoder_layer_9_16/bytenet_encoder_layer_9_16_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n",
      "bytenet_decoder_layer_0_1\n",
      "Tensor(\"bytenet_decoder_layer_0_1/bytenet_decoder_layer_0_1_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n",
      "bytenet_decoder_layer_1_2\n",
      "Tensor(\"bytenet_decoder_layer_1_2/bytenet_decoder_layer_1_2_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n",
      "bytenet_decoder_layer_2_4\n",
      "Tensor(\"bytenet_decoder_layer_2_4/bytenet_decoder_layer_2_4_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n",
      "bytenet_decoder_layer_3_8\n",
      "Tensor(\"bytenet_decoder_layer_3_8/bytenet_decoder_layer_3_8_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n",
      "bytenet_decoder_layer_4_16\n",
      "Tensor(\"bytenet_decoder_layer_4_16/bytenet_decoder_layer_4_16_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n",
      "bytenet_decoder_layer_5_1\n",
      "Tensor(\"bytenet_decoder_layer_5_1/bytenet_decoder_layer_5_1_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n",
      "bytenet_decoder_layer_6_2\n",
      "Tensor(\"bytenet_decoder_layer_6_2/bytenet_decoder_layer_6_2_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n",
      "bytenet_decoder_layer_7_4\n",
      "Tensor(\"bytenet_decoder_layer_7_4/bytenet_decoder_layer_7_4_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n",
      "bytenet_decoder_layer_8_8\n",
      "Tensor(\"bytenet_decoder_layer_8_8/bytenet_decoder_layer_8_8_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n",
      "bytenet_decoder_layer_9_16\n",
      "Tensor(\"bytenet_decoder_layer_9_16/bytenet_decoder_layer_9_16_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n",
      "bytenet_encoder_layer_0_1\n",
      "Tensor(\"while/bytenet_encoder_layer_0_1/bytenet_encoder_layer_0_1_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n",
      "bytenet_encoder_layer_1_2\n",
      "Tensor(\"while/bytenet_encoder_layer_1_2/bytenet_encoder_layer_1_2_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n",
      "bytenet_encoder_layer_2_4\n",
      "Tensor(\"while/bytenet_encoder_layer_2_4/bytenet_encoder_layer_2_4_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n",
      "bytenet_encoder_layer_3_8\n",
      "Tensor(\"while/bytenet_encoder_layer_3_8/bytenet_encoder_layer_3_8_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n",
      "bytenet_encoder_layer_4_16\n",
      "Tensor(\"while/bytenet_encoder_layer_4_16/bytenet_encoder_layer_4_16_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n",
      "bytenet_encoder_layer_5_1\n",
      "Tensor(\"while/bytenet_encoder_layer_5_1/bytenet_encoder_layer_5_1_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n",
      "bytenet_encoder_layer_6_2\n",
      "Tensor(\"while/bytenet_encoder_layer_6_2/bytenet_encoder_layer_6_2_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n",
      "bytenet_encoder_layer_7_4\n",
      "Tensor(\"while/bytenet_encoder_layer_7_4/bytenet_encoder_layer_7_4_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n",
      "bytenet_encoder_layer_8_8\n",
      "Tensor(\"while/bytenet_encoder_layer_8_8/bytenet_encoder_layer_8_8_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n",
      "bytenet_encoder_layer_9_16\n",
      "Tensor(\"while/bytenet_encoder_layer_9_16/bytenet_encoder_layer_9_16_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n",
      "bytenet_decoder_layer_0_1\n",
      "Tensor(\"while/bytenet_decoder_layer_0_1/bytenet_decoder_layer_0_1_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n",
      "bytenet_decoder_layer_1_2\n",
      "Tensor(\"while/bytenet_decoder_layer_1_2/bytenet_decoder_layer_1_2_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n",
      "bytenet_decoder_layer_2_4\n",
      "Tensor(\"while/bytenet_decoder_layer_2_4/bytenet_decoder_layer_2_4_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n",
      "bytenet_decoder_layer_3_8\n",
      "Tensor(\"while/bytenet_decoder_layer_3_8/bytenet_decoder_layer_3_8_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n",
      "bytenet_decoder_layer_4_16\n",
      "Tensor(\"while/bytenet_decoder_layer_4_16/bytenet_decoder_layer_4_16_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n",
      "bytenet_decoder_layer_5_1\n",
      "Tensor(\"while/bytenet_decoder_layer_5_1/bytenet_decoder_layer_5_1_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n",
      "bytenet_decoder_layer_6_2\n",
      "Tensor(\"while/bytenet_decoder_layer_6_2/bytenet_decoder_layer_6_2_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n",
      "bytenet_decoder_layer_7_4\n",
      "Tensor(\"while/bytenet_decoder_layer_7_4/bytenet_decoder_layer_7_4_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n",
      "bytenet_decoder_layer_8_8\n",
      "Tensor(\"while/bytenet_decoder_layer_8_8/bytenet_decoder_layer_8_8_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n",
      "bytenet_decoder_layer_9_16\n",
      "Tensor(\"while/bytenet_decoder_layer_9_16/bytenet_decoder_layer_9_16_1_1/Squeeze:0\", shape=(?, ?, 128), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = ByteNet(len(dictionary_from), len(dictionary_to), \n",
    "                residual_channels, encoder_dilations, decoder_dilations,\n",
    "                encoder_filter_width,decoder_filter_width)\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1, avg loss: 6.328035, avg accuracy: 0.072276\n",
      "epoch: 2, avg loss: 5.965444, avg accuracy: 0.094979\n",
      "epoch: 3, avg loss: 5.767229, avg accuracy: 0.105403\n",
      "epoch: 4, avg loss: 5.496558, avg accuracy: 0.143506\n",
      "epoch: 5, avg loss: 5.061917, avg accuracy: 0.187899\n",
      "epoch: 6, avg loss: 4.610906, avg accuracy: 0.227049\n",
      "epoch: 7, avg loss: 4.046287, avg accuracy: 0.279907\n",
      "epoch: 8, avg loss: 3.496574, avg accuracy: 0.345208\n",
      "epoch: 9, avg loss: 2.976003, avg accuracy: 0.413868\n",
      "epoch: 10, avg loss: 2.452929, avg accuracy: 0.501045\n",
      "epoch: 11, avg loss: 1.891984, avg accuracy: 0.604314\n",
      "epoch: 12, avg loss: 1.380336, avg accuracy: 0.710687\n",
      "epoch: 13, avg loss: 1.001658, avg accuracy: 0.793238\n",
      "epoch: 14, avg loss: 0.774593, avg accuracy: 0.842097\n",
      "epoch: 15, avg loss: 0.455521, avg accuracy: 0.927808\n",
      "epoch: 16, avg loss: 0.280280, avg accuracy: 0.970531\n",
      "epoch: 17, avg loss: 0.159310, avg accuracy: 1.002358\n",
      "epoch: 18, avg loss: 0.139503, avg accuracy: 1.003681\n",
      "epoch: 19, avg loss: 0.063580, avg accuracy: 1.019896\n",
      "epoch: 20, avg loss: 0.025332, avg accuracy: 1.023528\n"
     ]
    }
   ],
   "source": [
    "for i in range(epoch):\n",
    "    total_loss, total_accuracy = 0, 0\n",
    "    X, Y = shuffle(X, Y)\n",
    "    for k in range(0, len(text_to), batch_size):\n",
    "        index = min(k + batch_size, len(text_to))\n",
    "        batch_x, batch_y = pad_sentence_batch(X[k: index], Y[k: index], PAD)\n",
    "        predicted, accuracy, loss, _ = sess.run([tf.argmax(model.logits,2),\n",
    "                                      model.accuracy, model.cost, model.optimizer], \n",
    "                                      feed_dict={model.X:batch_x,\n",
    "                                                model.Y:batch_y})\n",
    "        total_loss += loss\n",
    "        total_accuracy += accuracy\n",
    "    total_loss /= (len(text_to) / batch_size)\n",
    "    total_accuracy /= (len(text_to) / batch_size)\n",
    "    print('epoch: %d, avg loss: %f, avg accuracy: %f'%(i+1, total_loss, total_accuracy))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(4, 42)"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predicted = sess.run(model.predicting_ids, \n",
    "                     feed_dict={model.X:batch_x,model.Y:batch_y})\n",
    "predicted.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "row 1\n",
      "QUESTION: no , no , no . it &apos;s four minutes long .\n",
      "REAL ANSWER: không không không . nó chỉ dài bốn phút thôi .\n",
      "PREDICTED ANSWER: không chỉ chỉ không được chỉ dài bốn phút mang đó pháp sử đi nếu ấy ngày hợp thao qua qua xấu hoặc trung tượng đi qua kiếm xấu qua rằng vấn tàu vấn rằng rằng thực thực đi chút tàu \n",
      "\n",
      "row 2\n",
      "QUESTION: and you know , it was not a failure for ourselves in itself , but it was a failure that will impact his full life .\n",
      "REAL ANSWER: và bạn biết đó , nó không phải là thất bại với chính bản thân chúng tôi mà là thất bại sẽ ảnh hưởng đến suốt đời mario .\n",
      "PREDICTED ANSWER: khi nếu muốn đó thuật việc đề được tìm thất bại khí ý đồng đen đó toàn mà cũng thất tác chọn ảnh hưởng đến suốt đời &quot; qua nói &quot; ấy tối phê trung xa qua tài xấu giá lý \n",
      "\n",
      "row 3\n",
      "QUESTION: but the choice was theirs , and our audience quickly grew to choose the richest and most varied diet that we could provide .\n",
      "REAL ANSWER: nhưng sự lựa chọn thuộc về chúng , và khán thính giả của chúng tôi tăng lên nhanh chóng để chọn những món &quot; ăn kiêng &quot; giàu nhất và đa dạng nhất mà chúng tôi có thể cung cấp .\n",
      "PREDICTED ANSWER: nhưng chiếc lựa chọn thuộc về chúng dành nên khán thính sát nhật nói 1 tăng lên tuần chóng nên chọn lên món &quot; mạo truyền hoá xã đến qua đem luật yêu rwanda xa rằng giành xấu giá giá &quot; \n",
      "\n",
      "row 4\n",
      "QUESTION: a freshly waxed car , the water molecules slump to about 90 degrees .\n",
      "REAL ANSWER: một xe vừa bôi sáp , những phân tử nước sụt xuống gần ̣ 90 độ .\n",
      "PREDICTED ANSWER: từ xe vừa bôi sáp xét từ khô tử nước sụt xuống gần ̣ trị độ hoặc nguồn đến nguồn đến dụ qua đi qua yêu trung trở trung xuôi tàu qua xấu vấn tàu xấu qua thực hiểu dụng &quot; \n",
      "\n"
     ]
    }
   ],
   "source": [
    "for i in range(len(batch_x)):\n",
    "    print('row %d'%(i+1))\n",
    "    print('QUESTION:',' '.join([rev_dictionary_from[n] for n in batch_x[i] if n not in [0,1,2,3]]))\n",
    "    print('REAL ANSWER:',' '.join([rev_dictionary_to[n] for n in batch_y[i] if n not in[0,1,2,3]]))\n",
    "    print('PREDICTED ANSWER:',' '.join([rev_dictionary_to[n] for n in predicted[i] if n not in[0,1,2,3]]),'\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
